#!/usr/bin/env python
# priforgepng

from __future__ import print_function

"""Forge PNG image from raw computation."""

from array import array
from fractions import Fraction

import argparse
import re
import sys

import png


def gen_glr(x):
    """
    Gradient Left to Right.

    The result is the horizontal coordinate `x` itself, unchanged.
    """
    level = x
    return level


def gen_grl(x):
    """
    Gradient Right to Left.

    The result is ``1 - x``: the mirror image of the
    left-to-right gradient.
    """
    mirrored = 1 - x
    return mirrored


def gen_gtb(x, y):
    """
    Gradient Top to Bottom.

    Only `y` is used; it is converted to a float and returned.
    """
    level = float(y)
    return level


def gen_gbt(x, y):
    """
    Gradient Bottom to Top.

    Only `y` is used; the result is ``1.0 - y`` as a float,
    the vertical mirror image of the top-to-bottom gradient.
    """
    distance_from_top = float(y)
    return 1.0 - distance_from_top


def gen_rtl(x, y):
    """
    Radial gradient, centred at top-left.

    The result is 1 at the origin and decreases linearly with the
    Euclidean distance from it; it is clamped so that it never
    goes below 0.0.
    """
    fx = float(x)
    fy = float(y)
    distance = (fx ** 2 + fy ** 2) ** 0.5
    falloff = 1 - distance
    return max(falloff, 0.0)


def gen_rctr(x, y):
    """
    Radial gradient, centred on the point (0.5, 0.5).

    Both coordinates are shifted by 0.5 and handed to gen_rtl(),
    which measures the distance from its origin.
    """
    dx = float(x) - 0.5
    dy = float(y) - 0.5
    return gen_rtl(dx, dy)


def gen_rtr(x, y):
    """
    Radial gradient, centred at top-right.

    The horizontal coordinate is mirrored (``1.0 - x``) and the
    result of gen_rtl() for the mirrored point is returned.
    """
    mirrored_x = 1.0 - float(x)
    return gen_rtl(mirrored_x, y)


def gen_rbl(x, y):
    """
    Radial gradient, centred at bottom-left.

    The vertical coordinate is mirrored (``1.0 - y``) and the
    result of gen_rtl() for the mirrored point is returned.
    """
    mirrored_y = 1.0 - float(y)
    return gen_rtl(x, mirrored_y)


def gen_rbr(x, y):
    """
    Radial gradient, centred at bottom-right.

    Both coordinates are mirrored (``1.0 - x``, ``1.0 - y``) and the
    result of gen_rtl() for the mirrored point is returned.
    """
    mirrored_x = 1.0 - float(x)
    mirrored_y = 1.0 - float(y)
    return gen_rtl(mirrored_x, mirrored_y)


def stripe(x, n):
    """
    Return 0 or 1 according to which band `x` falls in.

    `x` is scaled by `n` and truncated to an integer band index;
    even-numbered bands give 0 and odd-numbered bands give 1.
    """
    band = int(x * n)
    return band % 2


def gen_vs2(x):
    """
    Two bands across the range of `x`: 0 for the first half,
    1 for the second half.
    """
    return stripe(x, n=2)


def gen_vs4(x):
    """
    Four bands across the range of `x`,
    alternating 0 and 1 and starting with 0.
    """
    return stripe(x, n=4)


def gen_vs10(x):
    """
    Ten bands across the range of `x`,
    alternating 0 and 1 and starting with 0.
    """
    return stripe(x, n=10)


def gen_hs2(x, y):
    """
    Two bands down the range of `y`: 0 for the first half,
    1 for the second half.  `x` is ignored.
    """
    row = float(y)
    return stripe(row, 2)


def gen_hs4(x, y):
    """
    Four bands down the range of `y`,
    alternating 0 and 1 and starting with 0.  `x` is ignored.
    """
    row = float(y)
    return stripe(row, 4)


def gen_hs10(x, y):
    """
    Ten bands down the range of `y`,
    alternating 0 and 1 and starting with 0.  `x` is ignored.
    """
    row = float(y)
    return stripe(row, 10)


def gen_slr(x, y):
    """
    Diagonal stripes: the value is constant along lines of
    constant ``x + y`` and alternates between 0 and 1,
    with ten bands per unit of ``x + y``.
    """
    diagonal = x + y
    return stripe(diagonal, 10)


def gen_srl(x, y):
    """
    Diagonal stripes running the opposite way to gen_slr():
    the value is constant along lines of constant ``x - y``,
    with ten bands per unit of ``x - y``.
    """
    # The offset of 1 keeps the value handed to stripe()
    # non-negative for coordinates between 0 and 1.
    diagonal = 1 + x - y
    return stripe(diagonal, 10)


def checker(x, y, n):
    """
    Checkerboard with `n` bands in each direction.

    The result is 1 where the stripe() values of `x` and `y`
    differ and 0 where they agree.
    """
    column_band = stripe(x, n)
    row_band = stripe(y, n)
    return int(column_band != row_band)


def gen_ck8(x, y):
    """Checkerboard with 8 bands in each direction."""
    return checker(x, y, n=8)


def gen_ck15(x, y):
    """Checkerboard with 15 bands in each direction."""
    return checker(x, y, n=15)


def gen_zero(x):
    """Constant pattern: the result is 0 whatever `x` is."""
    constant = 0
    return constant


def gen_one(x):
    """Constant pattern: the result is 1 whatever `x` is."""
    constant = 1
    return constant


def yield_fun_rows(size, bitdepth, pattern):
    """
    Generate the rows of a single channel (monochrome) test pattern.

    `size` is a (width, height) pair, `bitdepth` is the number of
    bits per sample and `pattern` is a pattern name that is looked
    up with pattern_function().

    Each row is yielded in turn as an array of `width` integers,
    the pattern's value scaled by ``2 ** bitdepth - 1`` and rounded.
    """

    width, height = size

    top = 2 ** bitdepth - 1
    # Values above 255 do not fit the one-byte "B" items,
    # so switch to unsigned shorts ("H") for those.
    typecode = "H" if top > 255 else "B"
    shade = pattern_function(pattern)

    def quantise(value):
        # Scale a pattern value to an integer sample.
        return int(round(top * value))

    # Each pixel is sampled at its centre, so the coordinates are
    # (integer + 0.5) / extent.  This is morally better, and gives
    # all 256 sample values in a 256-pixel wide gradient.
    #
    # The column coordinates are computed once and shared by
    # every row, because Fraction instances are slow to allocate.
    columns = [Fraction(2 * i + 1, 2 * width) for i in range(width)]

    if n_args(shade) == 2:
        # General case: the pattern depends on both x and y,
        # so every row has to be computed afresh.
        for j in range(height):
            v = Fraction(2 * j + 1, 2 * height)
            yield array(typecode, [quantise(shade(u, v)) for u in columns])
        return

    # Special case: the pattern depends on x only, so all rows are
    # the same.  It's a _lot_ faster to compute one row and
    # yield that same array for every line of the image.
    shared_row = array(typecode, [quantise(shade(x=u)) for u in columns])
    for _ in range(height):
        yield shared_row


def generate(args):
    """
    Write a PNG test image to stdout for each requested pattern.

    `args` should be an argparse Namespace instance or similar;
    its `size`, `depth` and `pattern` attributes are used.
    One greyscale image (no alpha channel) is written per entry
    of `args.pattern`, one after another, to the same output stream.
    """

    dims = args.size
    depth = args.depth

    # presumably a binary-mode handle on stdout; see the png module
    sink = png.binary_stdout()

    options = dict(bitdepth=depth, greyscale=True, alpha=False)
    for name in args.pattern:
        pixel_rows = yield_fun_rows(dims, depth, name)
        encoder = png.Writer(dims[0], dims[1], **options)
        encoder.write(sink, pixel_rows)


def n_args(fun):
    """
    Return how many arguments appear in fun's argument list,
    as recorded in ``co_argcount`` of its code object.
    """
    code = fun.__code__
    return code.co_argcount


def pattern_function(pattern):
    """
    From `pattern`, a string,
    return the function for that pattern.

    The name is compared case-insensitively (it is lower-cased
    first) with the part that follows ``gen_`` in the names of the
    module-level ``gen_*`` objects.

    Raises ValueError if no pattern of that name is defined.
    """

    lpat = pattern.lower()
    # Iterate over a snapshot so the scan does not depend on the
    # module namespace staying unchanged while we look.
    for name, fun in list(globals().items()):
        parts = name.split("_")
        # A name with no underscore (a bare "gen", say) has no
        # pattern part; skip it rather than index past the end.
        if len(parts) < 2 or parts[0] != "gen":
            continue
        if parts[1] == lpat:
            return fun
    # Report the problem here, by name, instead of handing None
    # back to a caller that expects a function.
    raise ValueError("unknown pattern %r" % pattern)


def patterns():
    """
    List the patterns.

    Yields, for every module-level name of the form ``gen_<name>``,
    the ``<name>`` part (the text between the first and second
    underscore, if there is a second one).
    """

    # Iterate over a snapshot of the names: this is a generator, so
    # the module namespace could otherwise change between yields.
    for name in list(globals()):
        parts = name.split("_")
        # Require a pattern part after "gen"; a global called
        # exactly "gen" is not a pattern and must not be indexed.
        if len(parts) >= 2 and parts[0] == "gen":
            yield parts[1]


def dimensions(s):
    """
    Typecheck the --size option, which should be
    one or two comma separated numbers.
    Example: "64,40".

    Returns a two-element list ``[width, height]`` of ints;
    a single number is used for both width and height.

    Every run of digits in `s` counts as a number, so any
    non-digit text acts as a separator.

    Raises ValueError if `s` does not contain exactly one or two
    numbers, or if any of them is zero.
    """

    numbers = [int(digits) for digits in re.findall(r"\d+", s)]
    if len(numbers) not in (1, 2):
        raise ValueError("%r should be width or width,height" % s)
    # A zero-sized image cannot be generated; reject it here so the
    # problem is reported against the option that caused it.
    if min(numbers) < 1:
        raise ValueError("%r: width and height must be at least 1" % s)
    if len(numbers) == 1:
        numbers *= 2
    return numbers


def main(argv=None):
    """
    Command line entry point.

    `argv` is the full argument vector including the program name
    in position 0; it defaults to ``sys.argv``.

    With --list, the pattern names are printed in sorted order.
    Otherwise at least one pattern name is required, every name
    must be a known pattern, and the images are produced by
    generate().  Usage errors are reported through the argparse
    parser's error() method.
    """

    if argv is None:
        argv = sys.argv
    parser = argparse.ArgumentParser(description="Forge greyscale PNG patterns")

    parser.add_argument(
        "-l", "--list", action="store_true", help="print list of patterns and exit"
    )
    parser.add_argument(
        "-d", "--depth", default=8, type=int, metavar="N", help="N bits per pixel"
    )
    parser.add_argument(
        "-s",
        "--size",
        default=[256, 256],
        type=dimensions,
        metavar="w[,h]",
        help="width and height of the image in pixels",
    )
    parser.add_argument("pattern", nargs="*", help="name of pattern")

    args = parser.parse_args(argv[1:])

    if args.list:
        for name in sorted(patterns()):
            print(name)
        return

    if not args.pattern:
        parser.error("--list or pattern is required")

    # Reject unknown names up front, so that the user gets a usage
    # error before any image data is produced rather than a failure
    # part-way through generate().  Names are lower-cased because
    # pattern lookup is case-insensitive.
    known = set(patterns())
    unknown = [name for name in args.pattern if name.lower() not in known]
    if unknown:
        parser.error("unknown pattern: %s" % ", ".join(unknown))

    return generate(args)


# Run the command line interface when executed as a script.
if __name__ == "__main__":
    main(sys.argv)
